# Service account for the SkyPilot GPU labeler. The node-patcher-rolebinding
# ClusterRoleBinding in this file grants it node-patch permission.
apiVersion: v1
kind: ServiceAccount
metadata:
  labels:
    job: sky-gpu-labeler
    parent: skypilot
  name: gpu-labeler-sa
  namespace: kube-system

---
# Cluster-wide permission to patch Node objects, which is what the labeler
# script's patch_node call requires.
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
  # ClusterRole is a cluster-scoped resource, so no namespace is set here.
  name: node-patcher-role
  labels:
    parent: skypilot
    job: sky-gpu-labeler
rules:
# Nodes belong to the core ("") API group; "patch" is the only verb granted.
- apiGroups: [""]
  resources: ["nodes"]
  verbs: ["patch"]

---
# Grants node-patcher-role to the gpu-labeler-sa service account in
# kube-system.
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
  labels:
    job: sky-gpu-labeler
    parent: skypilot
  name: node-patcher-rolebinding
subjects:
- kind: ServiceAccount
  name: gpu-labeler-sa
  namespace: kube-system
roleRef:
  apiGroup: rbac.authorization.k8s.io
  kind: ClusterRole
  name: node-patcher-role

---
# ConfigMap that ships label_gpus.py, the script that labels a node with the
# name of its GPU.
apiVersion: v1
kind: ConfigMap
metadata:
  labels:
    job: sky-gpu-labeler
    parent: skypilot
  name: gpu-labeler-script
  namespace: kube-system
data:
  label_gpus.py: |
    #!/usr/bin/env python3
    import os
    import subprocess
    from typing import Optional
    
    from kubernetes import client
    from kubernetes import config
    
    # Canonical accelerator names; main() lower-cases the chosen entry and uses
    # it as the node label value. Order matters: main() stops at the first
    # entry that matches the detected GPU name, so a longer name must come
    # before any shorter name it contains (e.g. 'A100-80GB' before 'A100',
    # 'T4g' before 'T4', 'P40' before 'P4').
    canonical_gpu_names = [
        'A100-80GB',
        'A100',
        'A10G',
        'H100',
        'K80',
        'M60',
        'T4g',
        'T4',
        'V100',
        'A10',
        'P100',
        'P40',
        'P4',
        'L4',
    ]
    
    
    def get_gpu_name() -> Optional[str]:
        """Return the lower-cased name of the first GPU reported by nvidia-smi.

        Returns:
            The GPU name in lower case, or None when nvidia-smi cannot be
            started, exits with a non-zero status, or prints no GPU name.
        """
        try:
            result = subprocess.run(
                ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader,nounits'],
                stdout=subprocess.PIPE)
            if result.returncode != 0:
                # On failure stdout does not hold a GPU name (it may be empty
                # or carry an error message), so it must not be used as one.
                print('Error getting GPU name: nvidia-smi exited with code '
                      f'{result.returncode}')
                return None

            output_lines = result.stdout.decode('utf-8').splitlines()
            gpu_names = [line.strip() for line in output_lines if line.strip()]
            if not gpu_names:
                # Returning '' here would let the caller apply an empty label.
                print('Error getting GPU name: nvidia-smi reported no GPUs')
                return None

            # In the case of multi-gpu nodes, we assume the node is
            # homogeneous and just use the first GPU name.
            return gpu_names[0].lower()
        except Exception as e:
            # Best effort: e.g. nvidia-smi is not installed in the container.
            print(f'Error getting GPU name: {e}')
            return None
    
    
    def label_node(gpu_name: str) -> None:
        """Set the ``skypilot.co/accelerator`` label on the current node.

        The node is identified by the MY_NODE_NAME environment variable. Any
        failure (config loading, missing node name, API error) is printed and
        swallowed; nothing is raised to the caller.

        Args:
            gpu_name: Value to store in the ``skypilot.co/accelerator`` label.
        """
        try:
            # Build an API client from the in-cluster configuration.
            config.load_incluster_config()
            core_api = client.CoreV1Api()

            # The name of the node to patch comes from the environment.
            this_node = os.environ.get('MY_NODE_NAME')
            if not this_node:
                raise ValueError('Failed to get node name from environment')

            # Patch only the accelerator label on the node's metadata.
            accelerator_label = {'skypilot.co/accelerator': gpu_name}
            core_api.patch_node(this_node, {'metadata': {'labels': accelerator_label}})
            print(f'Labeled node {this_node} with GPU {gpu_name}')
        except Exception as err:
            print(f'Error labeling node: {err}')
    
    
    def main():
        """Detect this node's GPU and label the node with an accelerator name.

        The detected name is compared against canonical_gpu_names in order and
        the first match, lower-cased, becomes the label. If nothing matches, a
        label is derived from the detected name itself. When no GPU name is
        available, a hint is printed and the node is left unlabeled.
        """
        import re  # Function-scope import keeps this block self-contained.

        gpu_name = get_gpu_name()
        if not gpu_name:
            # Covers both None and '' so an empty label is never applied.
            print('No GPU detected. Try running nvidia-smi in the container.')
            return

        gpu_name = gpu_name.lower()
        for canonical_name in canonical_gpu_names:
            # Match the canonical name only as a complete model number: it
            # must not be preceded by a letter/digit nor followed by a digit.
            # A plain substring test would label 'nvidia l40' as 'l4' and
            # 'quadro p4000' as 'p40'. Letter suffixes (e.g. 'v100s') still
            # match their base model.
            pattern = (r'(?<![a-z0-9])' + re.escape(canonical_name.lower()) +
                       r'(?![0-9])')
            if re.search(pattern, gpu_name):
                label_node(canonical_name.lower())
                return

        # If we didn't find a canonical name:
        # 1. remove 'nvidia ' if present (e.g., 'nvidia rtx a6000' -> 'rtx a6000')
        # 2. remove 'geforce ' if present (e.g., 'nvidia geforce rtx 3070' -> 'rtx 3070')
        # 3. replace 'rtx ' with 'rtx' (without spaces) (e.g., 'rtx 6000' -> 'rtx6000')
        # 4. replace any other spaces with dashes (e.g. 'rtx 2080 ti' -> 'rtx2080-ti')
        gpu_label = (gpu_name.replace('nvidia ', '')
                     .replace('geforce ', '')
                     .replace('rtx ', 'rtx')
                     .replace(' ', '-'))
        label_node(gpu_label)
    
    
    # Run the labeler only when this file is executed as a script.
    if __name__ == '__main__':
        main()
